import copy
from tensorforce.environments import Environment
import numpy as np
from scipy.optimize import minimize, LinearConstraint
import mpctools as mpc
import casadi as cs
import pyomo.environ as pyo
from pyomo.opt import SolverFactory
from model.model_helper import ModelHelper
from utils.utils_helper import UtilsHelper
from analyze_plotCIS import plot_cis
from analyze_plotCIS import plot_cis_multiple


class RLModelHelper(Environment):  # Create a Child Class of Environment
    def __init__(self):
        """Configure the environment and build everything it needs to run.

        Sets the mode switches, problem dimensions and bounds, loads the
        half-space descriptions (rows of ``A x <= b``) of the control invariant
        sets (CIS) from ``data/``, prepares the plant simulator, the steady
        state and the IPOPT problem, and finally runs the parent constructor.
        """
        # --- Mode switches ---------------------------------------------------
        self.ini_zone = "cis"  # where reset() samples from: "hard_constraint" or "cis"
        self.target_zone = "cis"  # "hard_constraint" or "cis"
        self.zone_control = False  # include the zone-tracking term in the reward
        self.robust = True  # take the bounded disturbance into account

        # --- Dimensions and bounds -------------------------------------------
        self.nx, self.nu = 2, 1
        self.xlb = np.array([0.0, 345.0])  # physical hard constraints on the states
        self.xub = np.array([1.0, 355.0])
        self.ulb, self.uub = 285.0, 315.0  # physical hard constraints on the action
        self.hardConZone = np.array([345.0, 355.0])
        self.targetZone = np.array([348.0, 352.0])
        self.disturbance_bound = np.array([0.1, 2.0])

        # --- Invariant sets in half-space form -------------------------------
        self.hrepA = np.load("data/hrepA.npy")  # economic CIS
        self.hrepB = np.load("data/hrepB.npy")
        self.hrepABig_orig = np.load("data/hrepA_rcis_inner_pkg.npy")  # biggest CIS in the target zone
        self.hrepBBig_orig = np.load("data/hrepB_rcis_inner_pkg.npy")
        # Rows rescaled so that every normal vector has unit length.
        self.hrepABig, self.hrepBBig = self.norm_cis(self.hrepABig_orig, self.hrepBBig_orig)

        # Stacked [A | -b] matrices, the layout find_closest() slices apart.
        self.hrepB_expand_dims = self.hrepB[:, np.newaxis]
        self.halfspaces = np.hstack((self.hrepA, -self.hrepB_expand_dims))
        self.hrepBBig_expand_dims = self.hrepBBig[:, np.newaxis]
        self.halfspacesBig = np.hstack((self.hrepABig, -self.hrepBBig_expand_dims))

        # Negative tolerance added to b, i.e. the sets are shrunk slightly.
        self.tol = -2e-2

        # --- Helpers ---------------------------------------------------------
        self.prepare_simulators()  # plays the role of the real plant
        self.prepare_ss()
        self.construct_optimization_nlp()

        # Runtime flags; note that current_state is first assigned by reset()
        # (with actual_reset set to True) or by the backup helpers.
        self.actual_reset = None
        self.disturbance_worst = None

        super().__init__()

    def states(self):  # Will override parent's method
        return dict(type='float', shape=(self.nx,), min_value=self.xlb, max_value=self.xub)

    def actions(self):
        return {'Tc': dict(type='float', min_value=self.ulb, max_value=self.uub)}

    def max_episode_timesteps(self):
        """Return the episode length limit exactly as reported by the parent class."""
        return super().max_episode_timesteps()  # Use parent's method

    def close(self):
        """Close the environment by delegating to the parent class; nothing extra is released here."""
        super().close()

    def reset(self, num_parallel=None):
        self.timestep=0
        if self.actual_reset == True:
            outCIS = True
            if self.ini_zone == "hard_constraint":
                self.current_state = np.random.uniform(self.xlb, self.xub)
            elif self.ini_zone == "cis":
                while outCIS==True:
                    self.current_state = np.random.uniform(self.xlb, self.xub)
                    check = (np.matmul(self.hrepABig, self.current_state) <= self.hrepBBig + self.tol).all()  # True: initial state is within the CIS
                    if check == True:
                        outCIS = False
            else:
                print("Wrong input")
                quit()
        return self.current_state

    def execute(self, actions):
        """
        This method is actually a part of the "Safety Supervisor".
        In deterministic case, it is the "Model" block in "Safety Supervisor".
        In stochastic case, it is the combination of "Model" + "If xk+1 in CIS"
        Eventually,
            1. the reward is the indicator telling whether the action is safe
            2. the next state
        :param actions:
        :return next state, terminal, reward:
        """
        # Simulate the MDP
        assert actions['Tc'] >= self.ulb and actions['Tc'] <= self.uub
        self.timestep += 1
        # In both deterministic and stochastic cases, the reward compute solely depends on deterministic prediction
        xk = copy.deepcopy(self.current_state)
        self.current_state = self._response(actions)  # Next state - clean. Reward compute doesn't need stochastic state
        reward = self._reward_compute(xk, actions)  # Reward
        if self.robust == True:
            # The assignment of self.disturbance_worst is written in _reward_compute.
            # If the action was not safe, then the worst disturbance was obtained and assigned to self.disturbance_worst
            # If the action was safe, then it is None, and it will be sampled within the bounded disturbance ranges.
            # Theoretically, when agent's mode is independent and reward is negative/x is not in CIS, no need to run the
            # stochastic prediction because it will not be used in the main file. But here we run it anyway for
            # the simplicity of the code.
            self.current_state = self._response_stochastic(xk, actions, self.disturbance_worst)
        terminal = False
        return self.current_state, terminal, reward

    def _response(self, actions):
        xkp1 = self.plant(self.current_state, actions['Tc'], np.zeros(self.nx))
        return xkp1

    def _response_stochastic(self, state, actions, disturbance):
        if disturbance is None:
            disturbance = np.random.uniform(-self.disturbance_bound, self.disturbance_bound)
        xkp1 = self.plant(state, actions['Tc'], disturbance)     # need to impose uncertainties here
        return xkp1

    def _response_stochastic_sym(self, state, actions, disturbance):
        if disturbance is None:
            disturbance = np.random.uniform(-self.disturbance_bound, self.disturbance_bound)
        xkp1 = self.plant(state, actions, disturbance)     # need to impose uncertainties here
        return xkp1

    def _reward_compute(self, state, action):
        # # Target zone
        # if self.current_state[1] >= self.hardConZone[0] and self.current_state[1] <= self.hardConZone[1]:
        #     constraint_penalty = 100
        # else:
        #     constraint_penalty = -100*np.abs(min(self.current_state[1]-self.hardConZone))

        # Physical constraint. Since the agent has very large tendency to prescribe a bad policy, so we give a larger
        # reward when it prescribes a good policy.
        if (self.current_state < self.xlb).any() or (self.current_state > self.xub).any():
            constraint_penalty = 0 #-1000
        else:
            constraint_penalty = 0 #10000
        # Largest CIS
        # Robustness case ------------------------------------------------------------------
        # Need to check whether self.current_state is in or out of CIS with the presence of uncertainty.
        # The CIS we are using right now is the robust one. It means there is an action to maintain the system within
            # the set with the presence of uncertainty.
        # In prediction step, an action could be a safe one without the uncertainty, but could be the unsafe one after
            # adding the uncertainty onto the deterministic state. Then this action is not the safe action. We need to
            # find the action always drives the system within the CIS even the worst uncertainty is presented.
        # To compare the robust case with the deterministic one, the difference is the Safety supervisor. Previously,
            # when xk+1 is in CIS, we can make sure that when we apply uk to the Environment, xk+1 is also in CIS. Now,
            # when xk+1 is in CIS in Safety Supervisor, we do not know if the response of the Environment will be the
            # same or not. Hence, for "If xk+1 in CIS" block in the schematic diagram, we need to improve it.
        if self.robust == True:
            # Scipy optimization -----------------------------------------------------------
            # res = self.checkInOrOut(state, action)
            # check = res["fun"]
            # w_opt = res["x"]
            # Ipopt optimization -----------------------------------------------------------
            # res_nlp = self.checkInOrOut_nlp(state, action['Tc'])
            # check_nlp = res_nlp['f'].elements()[0]
            # MILP optimization ------------------------------------------------------------
            res = self.checkInOrOut_milp(state, action)
            check = pyo.value(res.cost_func)
            w_opt = [pyo.value(res.w[1]), pyo.value(res.w[2])]
            # for i in range(self.hrepABig.shape[0]):
            #     if pyo.value(res_milp.binary[i+1]) == 0:
            #         z = self.hrepBBig[i] + self.tol - np.matmul(self.hrepABig[i], self._response_stochastic(state, action, w_opt))
            #         print('The difference between z from MILP and z calcualted offline is', check_milp-z)
            # check_sign = check_milp*check
            # if check_sign < 0:
            #     print('LP with max and MILP produce different result')
            #     if check < 0:
            #         print(f'LP with max shows UNSAFE action with J={check} and w={res["x"]}, yet MILP shows SAFE action with J={check_milp} and w={w_opt}')
            #     else:
            #         print(f'MILP shows UNSAFE action with J={check_milp} and w={w_opt}, yet LP with max shows SAFE action with J={check} and w={res["x"]}')
            # ------------------------------------------------------------------------------
            if check >= 0:
                inCIS = True
                self.disturbance_worst = None
                # print('The action is safe')
            else:
                inCIS = False
                self.disturbance_worst = w_opt
                # print('The action is not safe, it is possible that xk+1 is outside of CIS when w is present. The worst disturbance is', self.disturbance_worst)
            # Till now, we do not use xk+1 at all. We need to make sure the action is safe, then we say the system is
            # driven to xk+1 safely.
        else:
            if (np.matmul(self.hrepABig, self.current_state) <= self.hrepBBig + self.tol).all():
                inCIS = True
            else:
                inCIS = False

        if inCIS==True:
            largeCIS_penalty = 10000 #900
            if self.zone_control == True:
                if self.current_state[1] > 352.0:
                    largeCIS_penalty -= (352.0-self.current_state[1])**2 * 300.0
                elif self.current_state[1] < 348.0:
                    largeCIS_penalty -= (348.0-self.current_state[1])**2 * 300.0
                else:
                    largeCIS_penalty += 0
        else:
            largeCIS_penalty = -1000 # -100 * np.abs(dis)
        # Zone control
        if self.target_zone == "hard_constraint":
            zone_penalty = 0
        elif self.target_zone == "cis":
            if (np.matmul(self.hrepA,self.current_state) <= self.hrepB + self.tol).all():
                zone_penalty = 0#500
            else:
                # res = self.find_closest(self.halfspaces, self.current_state)
                # dis = res["fun"]
                zone_penalty = 0  # -100*np.abs(dis)
        else:
            print("Wrong input of Target Zone")
            quit()
        reward = zone_penalty + largeCIS_penalty + constraint_penalty # + -self.current_state[0]
        return reward

    def prepare_simulators(self):
        """Create the model helper and expose its ``F`` attribute as ``self.plant``.

        ``self.plant`` is what the ``_response*`` methods call as
        ``plant(state, action, disturbance)``.
        """
        helper = ModelHelper([self.nx], [self.nu])
        # Kept on the instance so other methods can reach the helper as well.
        self.model_helper = helper
        self.plant = helper.F  # excluding the time information

    def prepare_ss(self):
        """Create the utilities helper and store what its ``prepare_ss`` returns.

        The two returned values are kept as ``self.xss`` and ``self.uss``.
        """
        # Kept on the instance so other methods can reach the helper as well.
        self.utils = UtilsHelper()
        steady = self.utils.prepare_ss()
        self.xss, self.uss = steady

    def norm_cis(self, Ha, Hb):
        magnitude = np.linalg.norm(Ha, axis=1)
        magnitude_inverse = 1 / magnitude
        transform_matrix = np.diag(magnitude_inverse)
        Ha_norm = np.matmul(transform_matrix, Ha)
        Hb_norm = np.matmul(transform_matrix, Hb)
        return Ha_norm, Hb_norm

    def dist(self, x, p):
        return np.linalg.norm(x - p)

    def find_closest(self, halfspaces, p):
        return minimize(
            self.dist,
            np.zeros(self.nx),
            args=(p,),
            constraints=[LinearConstraint(halfspaces[:, :-1], -np.inf, -halfspaces[:, -1])]
            )

    def objective(self, w, p, u):
        result = np.matmul(self.hrepABig, self._response_stochastic(p, u, w)) - self.hrepBBig - self.tol
        orig_max = np.max(result)
        k=1000
        log_sum_max = 1/k * np.log(np.sum(np.exp(k*result)))
        boltzmann_max = np.sum(result*np.exp(k*result))/np.sum(np.exp(k*result))
        # print(f'Different max methods produce {orig_max}, {log_sum_max}, {boltzmann_max}')
        return -orig_max

    def checkInOrOut(self, p, u):
        return minimize(
            self.objective,
            np.zeros(self.nx),
            args=(p, u),
            constraints=[LinearConstraint(np.eye(self.nx), -self.disturbance_bound, self.disturbance_bound)]
            )

    def objective_nlp(self, w, p, u):
        """Build the cost expression of the IPOPT problem (CasADi version of ``objective``).

        Called by ``construct_optimization_nlp`` with symbolic ``w``
        (disturbance), ``p`` (state) and ``u`` (action). The residual of every
        half-space is ``hrepABig @ x_next - hrepBBig - tol``; the returned cost
        is minus a smooth stand-in for the largest residual.

        :return: ``-boltzmann_max`` as a CasADi expression.
        """
        # Residual per half-space for the predicted next state.
        result = np.matmul(self.hrepABig, self._response_stochastic_sym(p, u, w)) - self.hrepBBig - self.tol
        # Exact maximum; computed but not used in the returned cost.
        orig_max = cs.mmax(result)
        # Sharpness of the smooth approximations below.
        # NOTE(review): with k this large exp(k * result) can overflow for residuals
        # well above zero - confirm the expected residual range.
        k=1000
        # Log-sum-exp approximation of the maximum; computed but not used.
        log_sum_max = 1 / k * cs.log(cs.sum1(cs.exp(k * result)))
        # Boltzmann operator: residuals averaged with weights exp(k * residual).
        boltzmann_max = cs.sum1(result*cs.exp(k*result))/cs.sum1(cs.exp(k*result))
        max_value = boltzmann_max
        return -max_value

    def construct_optimization_nlp(self):
        """Create the IPOPT solver used by ``checkInOrOut_nlp`` and store it in ``self.nlp``.

        The decision vector stacks the disturbance ``w``, the state ``p`` and
        the action ``u``; there are no general constraints, only the variable
        bounds supplied at call time. The cost comes from ``objective_nlp``.

        :return: the CasADi solver object (also kept as ``self.nlp``).
        """
        w_sym = cs.SX.sym('w', self.nx)
        p_sym = cs.SX.sym('p', self.nx)
        u_sym = cs.SX.sym('u', self.nu)
        cost = self.objective_nlp(w_sym, p_sym, u_sym)

        # Minimise f by changing x; x holds all three symbol groups.
        nlp_spec = {'x': cs.vertcat(w_sym, p_sym, u_sym), 'f': cost}

        # Solver options. Other options that were tried and left disabled:
        # "expand", "verbose", "ipopt.linear_solver" (mumps; ma27 is faster and
        # more robust but must be installed separately) and
        # "hessian_approximation" = "limited-memory" (for big problems, at the
        # cost of Hessian accuracy).
        solver_opts = {"ipopt.print_level": 0}

        self.nlp = cs.nlpsol('optimization', 'ipopt', nlp_spec, solver_opts)
        return self.nlp

    def checkInOrOut_nlp(self, p, u):
        """Search, with IPOPT, for the disturbance that is worst for state ``p`` and action ``u``.

        ``p`` and ``u`` are part of the solver's decision vector; they are
        pinned to the given values by using them as both lower and upper
        bound, so only the disturbance is free (within
        ``±disturbance_bound``). The search starts from a zero disturbance.

        :param p: state the step starts from.
        :param u: action value (not the action dict).
        :return: the raw solver result; ``res['f']`` is the optimal cost and
            ``res['x']`` the stacked solution ``[w, p, u]``.
        """
        lower = cs.vertcat(-self.disturbance_bound, p, u)
        upper = cs.vertcat(self.disturbance_bound, p, u)
        start = cs.vertcat(np.zeros(self.nx), p, u)
        # The numeric copies of the cost and solution that used to be extracted
        # here were never used; callers read them from the returned result.
        return self.nlp(lbx=lower, ubx=upper, x0=start)

    def checkInOrOut_milp(self, p, u):
        """Search, with a MILP, for the disturbance that is worst for state ``p`` and action ``u``.

        Model (indices are 1-based, as Pyomo range sets are):
          * ``w[1..2]``: disturbance, each within ``±disturbance_bound``;
          * ``z[1]``: the value being minimised;
          * ``binary[1..Nc]``: one flag per half-space, constrained to sum to
            ``Nc - 1`` so that exactly one flag is 0;
          * for every half-space ``i``:
            ``z + 100 * binary[i] + sum_j A[i, j] * 0.1 * w[j] >= b[i] + tol - A[i] @ x_nominal``
            where ``x_nominal`` is the disturbance-free prediction. The term
            ``100 * binary[i]`` relaxes every row except the one whose flag is
            0 (presumably 100 is large enough for that - TODO confirm).

        The caller treats a non-negative optimal ``z`` as "the action is safe".
        The problem is solved with Gurobi.
        NOTE(review): the solver status is not inspected; an infeasible or
        failed solve is not detected here.

        :param p: state the step starts from.
        :param u: action dict (key 'Tc').
        :return: the solved Pyomo model; read ``cost_func`` and ``w`` from it.
        """
        Nc = self.hrepABig.shape[0]

        # The disturbance-free prediction is the same for every half-space, so
        # the plant is evaluated once here instead of once per constraint row.
        nominal_next = self._response_stochastic(p, u, np.zeros(self.nx))

        model = pyo.ConcreteModel()
        model.w_size = pyo.RangeSet(2)
        model.z_size = pyo.RangeSet(1)
        model.binary_size = pyo.RangeSet(Nc)

        # Todo: add initial guess
        model.w = pyo.Var(model.w_size)
        model.z = pyo.Var(model.z_size)
        model.binary = pyo.Var(model.binary_size, within=pyo.Binary)

        def enforce_w_ub(model, i):
            return model.w[i] <= self.disturbance_bound[i-1]
        model.w_ub = pyo.Constraint(model.w_size, rule=enforce_w_ub)

        def enforce_w_lb(model, i):
            return model.w[i] >= -self.disturbance_bound[i-1]
        model.w_lb = pyo.Constraint(model.w_size, rule=enforce_w_lb)

        def enforce_active_constraint(model):
            return pyo.summation(model.binary) == Nc-1
        model.active_constraint = pyo.Constraint(rule=enforce_active_constraint)

        def mat_mul(v1, v2):
            # v1 is a 0-based numpy row, v2 the 1-based Pyomo variable.
            # NOTE(review): the fixed factor 0.1*(100.0/100.0) presumably mirrors how
            # the plant scales the disturbance - confirm against the model.
            return sum(v1[j] * 0.1*(100.0/100.0) * v2[j+1] for j in range(self.nx))
        def enforce_z_bound(model, i):
            bound = self.hrepBBig[i-1] + self.tol - np.matmul(self.hrepABig[i-1], nominal_next)
            matrix_multiplication = mat_mul(self.hrepABig[i-1], model.w)
            return model.z[1] + 100*model.binary[i] + matrix_multiplication >= bound
        model.z_bound = pyo.Constraint(model.binary_size, rule=enforce_z_bound)

        def build_cost_func(model):
            return model.z[1]
        model.cost_func = pyo.Objective(rule=build_cost_func, sense=pyo.minimize)

        solver = SolverFactory('gurobi')
        solver.solve(model)
        return model

    def backup(self, x, unsafe):
        # Discretize action space
        du = 100
        u = np.linspace(self.ulb, self.uub, du)
        # Select safe actions
        good_actions = []
        safe_actions = []
        for idx, item in enumerate(u):
            # For all actions, check if next state is in CIS or not. Do not update RL at all.
            self.actual_reset = False
            states_redundant = self.reset()
            self.current_state = copy.deepcopy(x)
            actions = {'Tc': item}
            states, terminal, reward = self.execute(actions=actions)
            # print(idx, item, states, reward)
            if reward >= 0:
                good_actions.append(actions)
                safe_actions.append(item)
        # Pick the safe one closest to the unsafe one todo: may need an optimization
        distance = np.abs(np.array(safe_actions) - unsafe['Tc'])
        index = np.where(distance == distance.min())[0][0]
        safe = safe_actions[index]
        print('Using backup plan. The unsafe action is', unsafe['Tc'], 'and the safe action is', safe)
        # Need to reset self.current_state back to previous_state
        self.current_state = copy.deepcopy(x)
        return {'Tc': safe}

    def backup_findSafeU(self, x):
        # Record episode experience
        episode_states = list()
        episode_actions = list()
        episode_terminal = list()
        episode_reward = list()
        # Discretize action space
        du = 100
        u = np.linspace(self.ulb, self.uub, du)
        # Select safe actions
        good_actions = []
        for idx, item in enumerate(u):
            # For all actions, check if next state is in CIS or not. Do not update RL at all.
            self.actual_reset = False
            states_redundant = self.reset()
            self.current_state = copy.deepcopy(x)
            actions = {'Tc': item}
            states, terminal, reward = self.execute(actions=actions)
            print(idx, item, states, reward)
            if reward >= 0:
                good_actions.append(actions)
                episode_states.append(x)
                episode_actions.append(actions)
                episode_terminal.append(True)
                episode_reward.append(reward)
        return episode_states, episode_actions, episode_terminal, episode_reward

    def backup_findClosestU(self, x, good_actions, unsafe):
        safe_actions = []
        for idx, item in enumerate(good_actions):
            safe_actions.append(item['Tc'])
        # Pick the safe one closest to the unsafe one todo: may need an optimization
        distance = np.abs(np.array(safe_actions) - unsafe['Tc'])
        index = np.where(distance == distance.min())[0][0]
        safe = safe_actions[index]
        print('Using backup plan. The unsafe action is', unsafe['Tc'], 'and the safe action is', safe)
        # Need to reset self.current_state back to previous_state
        self.current_state = copy.deepcopy(x)   # todo: maybe redundant
        return {'Tc': safe}

    def examine_state(self, x):
        # Deterministic case
        if (np.matmul(self.hrepABig, x) < self.hrepBBig + self.tol).all():
            inCIS_deter = True
            print('Deterministic: state is in CIS')
        else:
            inCIS_deter = False
            print('Deterministic: state is outside of CIS')

        # # Stochastic case
        # if self.robust == True:
        #     res = self.checkInOrOut(x)
        #     check = res["fun"]
        #     if check >= 0:
        #         inCIS_sto = True
        #         print('Stochastic: state is in CIS')
        #     else:
        #         inCIS_sto = False
        #         print('Stochastic: state is outside of CIS')

    def plot_state(self, x):
        """Plot state(s) ``x`` on top of the largest CIS, titled 'Failed state'.

        ``x`` may be a single state (1-D) or several states, one per row; a
        single state is turned into a one-row array before plotting.
        """
        zone_lo, zone_hi = 345, 355
        box_lo, box_hi = self.xlb, self.xub

        # Geometry of the set as prepared by plot_cis.
        cis_a, cis_b, c1, d1, x2S1 = plot_cis(
            zone_lo, zone_hi, self.hrepABig, self.hrepBBig, box_lo, box_hi
        )

        points = x.reshape((-1, x.shape[0])) if x.ndim == 1 else x

        # The same points are passed for both trajectory arguments.
        plot_cis_multiple([cis_a], [cis_b], c1, d1, [x2S1], zone_lo, zone_hi, box_lo, box_hi,
                          points, points,
                          title='Failed state', cis_tags=['rcis_inner'],
                          traj_tags=['failed state', 'failed state'])